{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uVoMtekMk-8i"
   },
   "source": [
    "# Finetuner un modèle de language masqué (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "97UMn_IHk-8n"
   },
   "source": [
    "Installez les bibliothèques 🤗 *Datasets*,  🤗 *Transformers* et 🤗 *Accelerate* pour exécuter ce *notebook*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JTP7moZJk-8r"
   },
   "outputs": [],
   "source": [
    "!pip install datasets transformers[sentencepiece]\n",
    "!pip install accelerate\n",
    "# Pour exécuter l'entraînement sur TPU, vous devez décommenter la ligne suivante :\n",
    "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W8b-VVo2k-8w"
   },
   "source": [
    "Vous aurez besoin de configurer git, adaptez votre email et votre nom dans la cellule suivante."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nKcOiNkCk-8z"
   },
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kDYcdgKqk-81"
   },
   "source": [
    "Vous devrez également être connecté au Hub d'Hugging Face. Exécutez ce qui suit et entrez vos informations d'identification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FaqLiMYJk-85"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bnwlqB03k-87"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModelForMaskedLM\n",
    "\n",
    "model_checkpoint = \"camembert-base\"\n",
    "model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bREfPqNMk-89"
   },
   "outputs": [],
   "source": [
    "text = \"C'est une grande <mask>.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5EjKRk9Bk-9C"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "15KV-A67k-9E"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "token_logits = model(**inputs).logits\n",
    "# Trouver l'emplacement du <mask> et extraire ses logits\n",
    "mask_token_index = torch.where(inputs[\"input_ids\"] == tokenizer.mask_token_id)[1]\n",
    "mask_token_logits = token_logits[0, mask_token_index, :]\n",
    "# Choisir les <mask> candidats avec les logits les plus élevés\n",
    "top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()\n",
    "\n",
    "for token in top_5_tokens:\n",
    "    print(f\"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lQuMuSuwk-9J"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "imdb_dataset = load_dataset(\"allocine\")\n",
    "imdb_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "svNO5I9lk-9M"
   },
   "outputs": [],
   "source": [
    "sample = imdb_dataset[\"train\"].shuffle(seed=42).select(range(3))\n",
    "\n",
    "for row in sample:\n",
    "    print(f\"\\n'>>> Review: {row['review']}'\")\n",
    "    print(f\"'>>> Label: {row['label']}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tm0x7w6Gk-9P"
   },
   "outputs": [],
   "source": [
    "def tokenize_function(examples):\n",
    "    result = tokenizer(examples[\"review\"])\n",
    "    if tokenizer.is_fast:\n",
    "        result[\"word_ids\"] = [result.word_ids(i) for i in range(len(result[\"input_ids\"]))]\n",
    "    return result\n",
    "\n",
    "\n",
    "# Utilisez batched=True pour activer le multithreading rapide !\n",
    "tokenized_datasets = imdb_dataset.map(\n",
    "    tokenize_function, batched=True, remove_columns=[\"review\", \"label\"]\n",
    ")\n",
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5zJvlXd2k-9R"
   },
   "outputs": [],
   "source": [
    "tokenizer.model_max_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tk3TfASAk-9T"
   },
   "outputs": [],
   "source": [
    "chunk_size = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "y0ZMbjwQk-9U"
   },
   "outputs": [],
   "source": [
    "# Le découpage produit une liste de listes pour chaque caractéristique\n",
    "tokenized_samples = tokenized_datasets[\"train\"][:3]\n",
    "\n",
    "for idx, sample in enumerate(tokenized_samples[\"input_ids\"]):\n",
    "    print(f\"'>>> Review {idx} length: {len(sample)}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Mg2PlCl8k-9X"
   },
   "outputs": [],
   "source": [
    "concatenated_examples = {\n",
    "    k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()\n",
    "}\n",
    "total_length = len(concatenated_examples[\"input_ids\"])\n",
    "print(f\"'>>> Concatenated reviews length: {total_length}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BU_uR_dLk-9Y"
   },
   "outputs": [],
   "source": [
    "chunks = {\n",
    "    k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]\n",
    "    for k, t in concatenated_examples.items()\n",
    "}\n",
    "\n",
    "for chunk in chunks[\"input_ids\"]:\n",
    "    print(f\"'>>> Chunk length: {len(chunk)}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y9Fwylomk-9b"
   },
   "outputs": [],
   "source": [
    "def group_texts(examples):\n",
    "    # Concaténation de tous les textes\n",
    "    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
    "    # Calculer la longueur des textes concaténés\n",
    "    total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
    "    # Nous laissons tomber le dernier morceau s'il est plus petit que chunk_size\n",
    "    total_length = (total_length // chunk_size) * chunk_size\n",
    "    # Fractionnement par morceaux de max_len\n",
    "    result = {\n",
    "        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]\n",
    "        for k, t in concatenated_examples.items()\n",
    "    }\n",
    "    # Créer une nouvelle colonne d'étiquettes\n",
    "    result[\"labels\"] = result[\"input_ids\"].copy()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0Ez61wwAk-9e"
   },
   "outputs": [],
   "source": [
    "lm_datasets = tokenized_datasets.map(group_texts, batched=True)\n",
    "lm_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TPeyRGujk-9f"
   },
   "outputs": [],
   "source": [
    "tokenizer.decode(lm_datasets[\"train\"][1][\"input_ids\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8Qz1pwNAk-9h"
   },
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sQIG43glk-9j"
   },
   "outputs": [],
   "source": [
    "samples = [lm_datasets[\"train\"][i] for i in range(2)]\n",
    "for sample in samples:\n",
    "    _ = sample.pop(\"word_ids\")\n",
    "\n",
    "for chunk in data_collator(samples)[\"input_ids\"]:\n",
    "    print(f\"\\n'>>> {tokenizer.decode(chunk)}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YHsPa9_Uk-9k"
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "import numpy as np\n",
    "\n",
    "from transformers import default_data_collator\n",
    "\n",
    "wwm_probability = 0.2\n",
    "\n",
    "\n",
    "def whole_word_masking_data_collator(features):\n",
    "    for feature in features:\n",
    "        word_ids = feature.pop(\"word_ids\")\n",
    "\n",
    "        # Création d'une correspondance entre les mots et les indices des tokens correspondants\n",
    "        mapping = collections.defaultdict(list)\n",
    "        current_word_index = -1\n",
    "        current_word = None\n",
    "        for idx, word_id in enumerate(word_ids):\n",
    "            if word_id is not None:\n",
    "                if word_id != current_word:\n",
    "                    current_word = word_id\n",
    "                    current_word_index += 1\n",
    "                mapping[current_word_index].append(idx)\n",
    "\n",
    "        # Masquer des mots de façon aléatoire\n",
    "        mask = np.random.binomial(1, wwm_probability, (len(mapping),))\n",
    "        input_ids = feature[\"input_ids\"]\n",
    "        labels = feature[\"labels\"]\n",
    "        new_labels = [-100] * len(labels)\n",
    "        for word_id in np.where(mask)[0]:\n",
    "            word_id = word_id.item()\n",
    "            for idx in mapping[word_id]:\n",
    "                new_labels[idx] = labels[idx]\n",
    "                input_ids[idx] = tokenizer.mask_token_id\n",
    "        feature[\"labels\"] = new_labels\n",
    "\n",
    "    return default_data_collator(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p3CVF4egk-9n"
   },
   "outputs": [],
   "source": [
    "samples = [lm_datasets[\"train\"][i] for i in range(2)]\n",
    "batch = whole_word_masking_data_collator(samples)\n",
    "\n",
    "for chunk in batch[\"input_ids\"]:\n",
    "    print(f\"\\n'>>> {tokenizer.decode(chunk)}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z2Ez4ONCk-9o"
   },
   "outputs": [],
   "source": [
    "train_size = 10_000\n",
    "test_size = int(0.1 * train_size)\n",
    "\n",
    "downsampled_dataset = lm_datasets[\"train\"].train_test_split(\n",
    "    train_size=train_size, test_size=test_size, seed=42\n",
    ")\n",
    "downsampled_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IUO-qIllk-9r"
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "batch_size = 64\n",
    "# Montrer la perte d'entraînement à chaque époque\n",
    "logging_steps = len(downsampled_dataset[\"train\"]) // batch_size\n",
    "model_name = model_checkpoint.split(\"/\")[-1]\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"{model_name}-finetuned-allocine\",\n",
    "    overwrite_output_dir=True,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    weight_decay=0.01,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    push_to_hub=True,\n",
    "    fp16=True,\n",
    "    logging_steps=logging_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ipKYzt6sk-9t"
   },
   "outputs": [],
   "source": [
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=downsampled_dataset[\"train\"],\n",
    "    eval_dataset=downsampled_dataset[\"test\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ruFr2J1sk-9u"
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "eval_results = trainer.evaluate()\n",
    "print(f\">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-fuBwvqek-9w"
   },
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "22r8rDYTk-9x"
   },
   "outputs": [],
   "source": [
    "eval_results = trainer.evaluate()\n",
    "print(f\">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cv0CwVamk-9z"
   },
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-GRfjc37k-91"
   },
   "outputs": [],
   "source": [
    "def insert_random_mask(batch):\n",
    "    features = [dict(zip(batch, t)) for t in zip(*batch.values())]\n",
    "    masked_inputs = data_collator(features)\n",
    "    # Créer une nouvelle colonne \"masquée\" pour chaque colonne du jeu de données\n",
    "    return {\"masked_\" + k: v.numpy() for k, v in masked_inputs.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RZbvYGpBk-91"
   },
   "outputs": [],
   "source": [
    "downsampled_dataset = downsampled_dataset.remove_columns([\"word_ids\"])\n",
    "eval_dataset = downsampled_dataset[\"test\"].map(\n",
    "    insert_random_mask,\n",
    "    batched=True,\n",
    "    remove_columns=downsampled_dataset[\"test\"].column_names,\n",
    ")\n",
    "eval_dataset = eval_dataset.rename_columns(\n",
    "    {\n",
    "        \"masked_input_ids\": \"input_ids\",\n",
    "        \"masked_attention_mask\": \"attention_mask\",\n",
    "        \"masked_labels\": \"labels\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "B953lLCEk-93"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator\n",
    "\n",
    "batch_size = 64\n",
    "train_dataloader = DataLoader(\n",
    "    downsampled_dataset[\"train\"],\n",
    "    shuffle=True,\n",
    "    batch_size=batch_size,\n",
    "    collate_fn=data_collator,\n",
    ")\n",
    "eval_dataloader = DataLoader(\n",
    "    eval_dataset, batch_size=batch_size, collate_fn=default_data_collator\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qk987sI3k-95"
   },
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=5e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mgKQR4Aok-96"
   },
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "\n",
    "accelerator = Accelerator()\n",
    "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
    "    model, optimizer, train_dataloader, eval_dataloader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p7iQC9nZk-97"
   },
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_train_epochs = 3\n",
    "num_update_steps_per_epoch = len(train_dataloader)\n",
    "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3b8pk5NRk-98"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import get_full_repo_name\n",
    "\n",
    "model_name = \"camembert-base-finetuned-allocine-accelerate\"\n",
    "repo_name = get_full_repo_name(model_name)\n",
    "repo_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0dPTV7r6k-9-"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import Repository\n",
    "\n",
    "output_dir = model_name\n",
    "repo = Repository(output_dir, clone_from=repo_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mekFif4nk-9_"
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import math\n",
    "\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "for epoch in range(num_train_epochs):\n",
    "    # Entraînement\n",
    "    model.train()\n",
    "    for batch in train_dataloader:\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)\n",
    "\n",
    "    # Evaluation\n",
    "    model.eval()\n",
    "    losses = []\n",
    "    for step, batch in enumerate(eval_dataloader):\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "\n",
    "        loss = outputs.loss\n",
    "        losses.append(accelerator.gather(loss.repeat(batch_size)))\n",
    "\n",
    "    losses = torch.cat(losses)\n",
    "    losses = losses[: len(eval_dataset)]\n",
    "    try:\n",
    "        perplexity = math.exp(torch.mean(losses))\n",
    "    except OverflowError:\n",
    "        perplexity = float(\"inf\")\n",
    "\n",
    "    print(f\">>> Epoch {epoch}: Perplexity: {perplexity}\")\n",
    "\n",
    "    # Sauvegarder et télécharger\n",
    "    accelerator.wait_for_everyone()\n",
    "    unwrapped_model = accelerator.unwrap_model(model)\n",
    "    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n",
    "    if accelerator.is_main_process:\n",
    "        tokenizer.save_pretrained(output_dir)\n",
    "        repo.push_to_hub(\n",
    "            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qpGfXsN3k--B"
   },
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "mask_filler = pipeline(\n",
    "    \"fill-mask\", model=\"huggingface-course/camembert-base-finetuned-allocine\", tokenizer=\"huggingface-course/camembert-base-finetuned-allocine\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "q0EDeIUJk--C"
   },
   "outputs": [],
   "source": [
    "preds = mask_filler(text)\n",
    "\n",
    "for pred in preds:\n",
    "    print(f\">>> {pred['sequence']}\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15 (main, Oct 11 2022, 22:27:25) \n[Clang 14.0.0 (clang-1400.0.29.102)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "397704579725e15f5c7cb49fe5f0341eb7531c82d19f2c29d197e8b64ab5776b"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
